package main

import (
	"bufio"
	"bytes"
	"cogpt/internal/cache"
	"cogpt/internal/cogpt"
	myErr "cogpt/internal/cogpterror"
	"cogpt/internal/config"
	"cogpt/internal/log"
	"cogpt/internal/proxy"
	"cogpt/internal/share"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)

// handlerRoot serves the `/` route: it answers every request with HTTP 200
// and a short plain-text greeting.
func handlerRoot(c *gin.Context) {
	const greeting = "Hi, it's CoGPT!"
	c.String(http.StatusOK, greeting)
}

// handlerHealth serves the `/health` route. It always replies with HTTP 200
// and the JSON object {"status": "OK"}; no dependency is probed.
func handlerHealth(c *gin.Context) {
	payload := gin.H{"status": "OK"}
	c.JSON(http.StatusOK, payload)
}

// handlerV1ChatCompletions is the handler for `/v1/chat/completions` route.
//
// It resolves the caller's app token from the `Authorization: Bearer <token>`
// header, exchanges it for a Copilot token, forwards the request to the
// Copilot chat completions endpoint and relays the upstream body line by line.
// Each relayed JSON payload gets `object`, `model` and `created` filled in when
// upstream left them empty; payloads without choices are dropped. With
// `stream: true` every payload is re-framed as a server-sent event
// (`data: ...` followed by a blank line), including the final `data: [DONE]`.
func handlerV1ChatCompletions(c *gin.Context) {
	// maxUpstreamLineSize caps one line of the upstream body. The body is
	// parsed line by line, and a non-streamed completion is expected on a
	// single line; the bufio.Scanner default of 64 KiB would reject long
	// answers with "token too long".
	const maxUpstreamLineSize = 16 << 20

	var err error
	var myerr myErr.Error

	upstreamURL := "https://api.githubcopilot.com/chat/completions"

	// get app token
	appToken, ok := share.GetRealToken(strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer "))
	if !ok {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msg("Bad authorization header")
		abortWithError(c, http.StatusUnauthorized, "Bad token")
		return
	}

	// parse request body; the literal supplies defaults for fields the
	// client omits (decoding leaves absent fields untouched)
	requestBody := cogpt.CompletionsRequest{
		Model:       "gpt-3.5-turbo",
		Temperature: 0.5,
		Top_p:       1,
		N:           1,
		Stream:      false,
	}
	err = c.ShouldBindJSON(&requestBody)
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to parse request body: %s", err.Error())
		abortWithError(c, http.StatusBadRequest, err.Error())
		return
	}
	// check message
	if len(requestBody.Messages) == 0 {
		abortWithError(c, http.StatusBadRequest, "messages field is required")
		return
	}

	// get copilot token; map each failure class to an HTTP status
	copilotToken, myerr := cogpt.GetCopilotToken(appToken)
	if myerr != nil {
		switch myerr.Code() {
		case myErr.ERR_FAILED_TO_CREATE_HTTP_CLIENT, myErr.ERR_FAILED_TO_FINISH_HTTP_REQUEST, myErr.ERR_FAILED_TO_PARSE_JSON:
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to request copilot token: %s", myerr.Error())
			abortWithError(c, http.StatusInternalServerError, myerr.Error())

		case myErr.ERR_FAILED_TO_READ_HTTP_RESPONSE_BODY:
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to read http response body: %s", myerr.Error())
			abortWithError(c, http.StatusGatewayTimeout, myerr.Error())

		case myErr.ERR_NO_COPILOT_ACCESS:
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msg("No copilot access")
			abortWithError(c, http.StatusForbidden, "You don't have access to GitHub Copilot. Please sign up first.")

		case myErr.ERR_FAILED_TO_GET_COPILOT_TOKEN:
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to get copilot token: %s", myerr.Error())
			abortWithError(c, http.StatusBadGateway, myerr.Error())

		default:
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Unknown error: %s", myerr.Error())
			abortWithError(c, http.StatusInternalServerError, myerr.Error())
		}
		return
	}

	// get machineID and sessionID
	item, myerr := cache.CacheInstance.Get(appToken)
	if myerr != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msg("Failed to get machineID and sessionID")
		abortWithError(c, http.StatusInternalServerError, myerr.Error())
		return
	}
	machineID := item.Vscode_machineid
	sessionID := item.Vscode_sessionid

	// generate headers
	headers := cogpt.GenHeaders(copilotToken, machineID, sessionID, requestBody.Stream)

	// build the upstream request
	client, myerr := proxy.GenProxyClient(config.ConfigInstance.Proxy, 0)
	if myerr != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to create http client: %s", myerr.Error())
		abortWithError(c, http.StatusInternalServerError, myerr.Error())
		return
	}
	requestBodyJSON, err := json.Marshal(requestBody)
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to marshal request body: %s", err.Error())
		abortWithError(c, http.StatusInternalServerError, err.Error())
		return
	}
	// Bind the upstream call to the inbound request so it is cancelled when
	// the client disconnects instead of running to completion for nobody.
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, upstreamURL, bytes.NewReader(requestBodyJSON))
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to create http request: %s", err.Error())
		abortWithError(c, http.StatusInternalServerError, err.Error())
		return
	}
	for key, value := range headers {
		req.Header.Set(key, value)
	}

	// send request
	resp, err := client.Do(req)
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to send http request: %s", err.Error())
		abortWithError(c, http.StatusInternalServerError, err.Error())
		return
	}
	defer resp.Body.Close()

	// check response status code
	if resp.StatusCode != http.StatusOK {
		respBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to read response body: %s", readErr.Error())
			abortWithError(c, http.StatusInternalServerError, readErr.Error())
			return
		}
		// req.Body was already drained by client.Do, so log the bytes we sent.
		log.Logger.Warn().Str("type", "handlerV1ChatCompletions").
			Str("status", resp.Status).
			Str("rawRequestBody", string(requestBodyJSON)).
			Str("responseBody", string(respBody)).
			Msg("Status code is not 200")
		abortWithError(c, resp.StatusCode, string(respBody))
		return
	}

	// set response headers for the relayed body
	c.Header("Transfer-Encoding", "chunked")
	c.Header("X-Accel-Buffering", "no")
	if requestBody.Stream {
		c.Header("Content-Type", "text/event-stream; charset=utf-8")
	} else {
		c.Header("Content-Type", "application/json; charset=utf-8")
	}
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	// Object name filled in when upstream omits it, matching the OpenAI
	// chat completion / chat completion chunk objects.
	object := "chat.completion"
	if requestBody.Stream {
		object = "chat.completion.chunk"
	}

	// fail reports an error hit while relaying. Once anything has been
	// written the 200 status line is committed, so an error body would only
	// corrupt the half-sent stream; in that case the caller just stops.
	fail := func(status int, msg string) {
		if !c.Writer.Written() {
			abortWithError(c, status, msg)
		}
	}

	// scan response body line by line
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxUpstreamLineSize)
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 {
			// Upstream event separators; events are re-framed below.
			continue
		}

		var out []byte
		if bytes.Contains(line, []byte("data: [DONE]")) {
			out = line
			if requestBody.Stream {
				// The scanner strips line terminators; restore the blank
				// line that terminates an SSE event.
				out = []byte("data: [DONE]\n\n")
			}
		} else {
			data := cogpt.CompletionsResponse{}
			if err := json.Unmarshal(bytes.TrimPrefix(line, []byte("data: ")), &data); err != nil {
				log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to unmarshal response body: %s", err.Error())
				fail(http.StatusInternalServerError, err.Error())
				return
			}
			if len(data.Choices) == 0 {
				continue
			}
			if data.Object == "" {
				data.Object = object
			}
			if data.Model == "" {
				data.Model = requestBody.Model
			}
			if data.Created == 0 {
				data.Created = time.Now().Unix()
			}

			newLine, err := json.Marshal(data)
			if err != nil {
				log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to marshal response body: %s", err.Error())
				fail(http.StatusInternalServerError, err.Error())
				return
			}
			if requestBody.Stream {
				out = []byte(fmt.Sprintf("data: %s\n\n", newLine))
			} else {
				out = newLine
			}
		}

		if _, err := c.Writer.Write(out); err != nil {
			// Most likely the client went away; nothing left to deliver.
			log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to write response: %s", err.Error())
			return
		}
		c.Writer.Flush()
	}

	if err := scanner.Err(); err != nil {
		log.Logger.Error().Str("type", "handlerV1ChatCompletions").Msgf("Failed to scan response body: %s", err.Error())
		fail(http.StatusBadGateway, err.Error())
	}
}

// handlerV1Embeddings is the handler for `/v1/embeddings` route.
//
// It resolves the caller's app token from the `Authorization: Bearer <token>`
// header, exchanges it for a Copilot token, forwards the request to the
// Copilot embeddings endpoint and relays the upstream JSON response. A
// top-level `input` that is not a JSON array is wrapped into a one-element
// array before forwarding. Each relayed payload gets `object` ("list"), the
// per-item `object` ("embedding") and `model` filled in when upstream left
// them empty; payloads without data items are dropped.
func handlerV1Embeddings(c *gin.Context) {
	// maxUpstreamLineSize caps one line of the upstream body. The body is
	// parsed line by line, so a whole response must fit in one line buffer;
	// a response carrying several embedding vectors can easily exceed the
	// bufio.Scanner default of 64 KiB ("token too long").
	const maxUpstreamLineSize = 16 << 20

	var err error
	var myerr myErr.Error

	upstreamURL := "https://api.githubcopilot.com/embeddings"

	// get app token
	appToken, ok := share.GetRealToken(strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer "))
	if !ok {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msg("Bad authorization header")
		abortWithError(c, http.StatusUnauthorized, "Bad token")
		return
	}

	// parse request body; the literal supplies the default model
	requestBody := cogpt.EmbeddingsRequest{
		Model: "text-embedding-ada-002",
	}
	err = c.ShouldBindJSON(&requestBody)
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to parse request body: %s", err.Error())
		abortWithError(c, http.StatusBadRequest, err.Error())
		return
	}
	// input is required; a single value is normalized into a list
	if requestBody.Input == nil {
		abortWithError(c, http.StatusBadRequest, "input field is required")
		return
	}
	if _, ok := requestBody.Input.([]any); !ok {
		requestBody.Input = []any{requestBody.Input}
	}

	// get copilot token; map each failure class to an HTTP status
	copilotToken, myerr := cogpt.GetCopilotToken(appToken)
	if myerr != nil {
		switch myerr.Code() {
		case myErr.ERR_FAILED_TO_CREATE_HTTP_CLIENT, myErr.ERR_FAILED_TO_FINISH_HTTP_REQUEST, myErr.ERR_FAILED_TO_PARSE_JSON:
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to request copilot token: %s", myerr.Error())
			abortWithError(c, http.StatusInternalServerError, myerr.Error())

		case myErr.ERR_FAILED_TO_READ_HTTP_RESPONSE_BODY:
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to read http response body: %s", myerr.Error())
			abortWithError(c, http.StatusGatewayTimeout, myerr.Error())

		case myErr.ERR_NO_COPILOT_ACCESS:
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msg("No copilot access")
			abortWithError(c, http.StatusForbidden, "You don't have access to GitHub Copilot. Please sign up first.")

		case myErr.ERR_FAILED_TO_GET_COPILOT_TOKEN:
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to get copilot token: %s", myerr.Error())
			abortWithError(c, http.StatusBadGateway, myerr.Error())

		default:
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Unknown error: %s", myerr.Error())
			abortWithError(c, http.StatusInternalServerError, myerr.Error())
		}
		return
	}

	// get machineID and sessionID
	item, myerr := cache.CacheInstance.Get(appToken)
	if myerr != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msg("Failed to get machineID and sessionID")
		abortWithError(c, http.StatusInternalServerError, myerr.Error())
		return
	}
	machineID := item.Vscode_machineid
	sessionID := item.Vscode_sessionid

	// generate headers (embeddings are never streamed)
	headers := cogpt.GenHeaders(copilotToken, machineID, sessionID, false)

	// build the upstream request
	client, myerr := proxy.GenProxyClient(config.ConfigInstance.Proxy, 0)
	if myerr != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to create http client: %s", myerr.Error())
		abortWithError(c, http.StatusInternalServerError, myerr.Error())
		return
	}
	requestBodyJSON, err := json.Marshal(requestBody)
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to marshal request body: %s", err.Error())
		abortWithError(c, http.StatusInternalServerError, err.Error())
		return
	}
	// Bind the upstream call to the inbound request so it is cancelled when
	// the client disconnects.
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, upstreamURL, bytes.NewReader(requestBodyJSON))
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to create http request: %s", err.Error())
		abortWithError(c, http.StatusInternalServerError, err.Error())
		return
	}
	for key, value := range headers {
		req.Header.Set(key, value)
	}

	// send request
	resp, err := client.Do(req)
	if err != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to send http request: %s", err.Error())
		abortWithError(c, http.StatusInternalServerError, err.Error())
		return
	}
	defer resp.Body.Close()

	// check response status code
	if resp.StatusCode != http.StatusOK {
		respBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to read response body: %s", readErr.Error())
			abortWithError(c, http.StatusInternalServerError, readErr.Error())
			return
		}
		// req.Body was already drained by client.Do, so log the bytes we sent.
		log.Logger.Warn().Str("type", "handlerV1Embeddings").
			Str("status", resp.Status).
			Str("rawRequestBody", string(requestBodyJSON)).
			Str("responseBody", string(respBody)).
			Msg("Status code is not 200")
		abortWithError(c, resp.StatusCode, string(respBody))
		return
	}

	// set response headers for the relayed body
	c.Header("Transfer-Encoding", "chunked")
	c.Header("X-Accel-Buffering", "no")
	c.Header("Content-Type", "application/json; charset=utf-8")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	// fail reports an error hit while relaying. Once anything has been
	// written the 200 status line is committed, so an error body would only
	// corrupt the partial response; in that case the caller just stops.
	fail := func(status int, msg string) {
		if !c.Writer.Written() {
			abortWithError(c, status, msg)
		}
	}

	// scan response body line by line
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxUpstreamLineSize)
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 {
			continue
		}

		data := cogpt.EmbeddingsResponse{}
		if err := json.Unmarshal(line, &data); err != nil {
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to unmarshal response body: %s", err.Error())
			fail(http.StatusInternalServerError, err.Error())
			return
		}
		if len(data.Data) == 0 {
			continue
		}
		if data.Object == "" {
			data.Object = "list"
		}
		for i := range data.Data {
			if data.Data[i].Object == "" {
				data.Data[i].Object = "embedding"
			}
		}
		if data.Model == "" {
			data.Model = requestBody.Model
		}

		out, err := json.Marshal(data)
		if err != nil {
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to marshal response body: %s", err.Error())
			fail(http.StatusInternalServerError, err.Error())
			return
		}

		if _, err := c.Writer.Write(out); err != nil {
			// Most likely the client went away; nothing left to deliver.
			log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to write response: %s", err.Error())
			return
		}
		c.Writer.Flush()
	}

	if err := scanner.Err(); err != nil {
		log.Logger.Error().Str("type", "handlerV1Embeddings").Msgf("Failed to scan response body: %s", err.Error())
		fail(http.StatusBadGateway, err.Error())
	}
}

// handlerV1Models serves the `/v1/models` route. It replies with HTTP 200 and
// a list-shaped JSON object (`object: "list"`, `data: [...]`) whose entries
// are built by cogpt.GenModel, one per advertised model name, in order.
func handlerV1Models(c *gin.Context) {
	names := []string{"gpt-3.5-turbo", "gpt-4"}

	models := make([]gin.H, 0, len(names))
	for _, name := range names {
		models = append(models, cogpt.GenModel(name))
	}

	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   models,
	})
}

// handlerRobots serves the `/robots.txt` route. It returns HTTP 200 with a
// plain-text policy that disallows every path for every user agent.
func handlerRobots(c *gin.Context) {
	const policy = "User-agent: *\nDisallow: /"

	c.Header("Content-Type", "text/plain")
	c.String(http.StatusOK, policy)
}
